import torch
import scipy

import numpy as np
import pandas as pd

from pathlib import Path
from sed_scores_eval.utils.scores import create_score_dataframe

"""A gpu-version decoder for strong predictions; This func gives almost identical result as the original DCASE version (error < 0.001)
"""

def batched_decode_preds(
    strong_preds, filenames, encoder, thresholds=[0.5], median_filter=[7] * 10, pad_indx=None,
):
    """ Decode a batch of strong predictions into score dataframes and event dataframes.

    For every file of the batch the raw scores are stored, each class is smoothed
    along time with its own median window, the smoothed scores are stored, and the
    smoothed scores are binarized at every threshold and decoded into events.

    Args:
        strong_preds: torch.Tensor, batch of strong predictions, indexed as
            (batch, classes, frames).
        filenames: list, the list of filenames of the current batch. Only the stem
            is kept; events are reported under "<stem>.wav".
        encoder: object used to decode predictions. It must provide `labels`,
            `_frame_to_time` and `decode_strong`.
        thresholds: list, the list of thresholds to be used for predictions
            (a frame is active when its smoothed score is strictly greater).
        median_filter: list of int, the median window size (in frames) of each class,
            in class order; entries beyond the number of classes are ignored.
            A single int is also accepted and applied to every class.
        pad_indx: optional sequence with one value per file (read through `.item()`);
            each value is multiplied by the number of frames to get the number of
            frames to keep, the remaining frames being treated as padding.

    Returns:
        tuple (scores_raw, scores_postprocessed, prediction_dfs):
            scores_raw: dict, audio id -> score dataframe of the unfiltered scores.
            scores_postprocessed: dict, audio id -> score dataframe of the
                median-filtered scores.
            prediction_dfs: dict, threshold -> DataFrame of decoded events with the
                columns "event_label", "onset", "offset" and "filename".
    """
    # Imported locally under an alias: the `median_filter` parameter shadows the SciPy
    # function name, and the `scipy.ndimage.filters` namespace is deprecated.
    from scipy.ndimage import median_filter as ndimage_median_filter

    # NOTE: the list defaults in the signature are never mutated here.
    single_window = isinstance(median_filter, (int, np.integer))

    scores_raw = {}
    scores_postprocessed = {}
    # Per-file event frames are collected and concatenated once per threshold at the
    # end, instead of growing a DataFrame inside the loop.
    event_frames = {threshold: [] for threshold in thresholds}

    for j in range(strong_preds.shape[0]):  # over batches
        audio_id = Path(filenames[j]).stem
        filename = audio_id + ".wav"
        c_scores = strong_preds[j]
        if pad_indx is not None:
            true_len = int(c_scores.shape[-1] * pad_indx[j].item())
            # Crop the frame (last) axis: true_len is a number of frames, so slicing
            # the first axis would cut classes instead of removing the padding.
            c_scores = c_scores[:, :true_len]
        # (classes, frames) -> (frames, classes)
        c_scores = c_scores.transpose(0, 1).detach().cpu().numpy()
        scores_raw[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )

        if single_window:
            window_sizes = [median_filter] * c_scores.shape[1]
        else:
            window_sizes = median_filter
        # Each class is smoothed independently along time with its own window.
        smoothed_columns = [
            ndimage_median_filter(class_scores[:, np.newaxis], (window, 1))
            for class_scores, window in zip(c_scores.T, window_sizes)
        ]
        c_scores = np.concatenate(smoothed_columns, axis=1)
        scores_postprocessed[audio_id] = create_score_dataframe(
            scores=c_scores,
            timestamps=encoder._frame_to_time(np.arange(len(c_scores) + 1)),
            event_classes=encoder.labels,
        )

        for c_th in thresholds:
            pred = encoder.decode_strong(c_scores > c_th)
            pred = pd.DataFrame(pred, columns=["event_label", "onset", "offset"])
            pred["filename"] = filename
            event_frames[c_th].append(pred)

    # An empty batch yields an empty DataFrame per threshold.
    prediction_dfs = {
        threshold: pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
        for threshold, frames in event_frames.items()
    }

    return scores_raw, scores_postprocessed, prediction_dfs

def val_decode_preds(strong_preds, thds, median_filter):
    """ Binarize strong predictions at each threshold and smooth the result.

    Args:
        strong_preds: torch.Tensor, strong predictions of shape (batch, classes, frames).
        thds: sequence of thresholds; a frame is active when its score is strictly
            greater than the threshold.
        median_filter: callable applied to the binarized predictions (given as a
            float tensor of zeros and ones); its output is what gets returned.

    Returns:
        torch.Tensor: for a single threshold, the output of `median_filter`; for
        several thresholds, the per-threshold outputs stacked along a new leading
        dimension (one entry per threshold, in the order of `thds`).

    Raises:
        ValueError: if `strong_preds` is not 3-dimensional or `thds` is empty.
    """
    if strong_preds.dim() != 3:
        raise ValueError(
            f"strong_preds must be 3-dimensional (batch, classes, frames), got {strong_preds.dim()} dimension(s)"
        )
    if len(thds) == 0:
        raise ValueError("thds must contain at least one threshold")

    smoothed = []
    for thd in thds:
        # The threshold lives on the same device as the predictions, so CPU tensors
        # and non-default CUDA devices work too. Its (1, 1, 1) shape broadcasts over
        # the whole batch without materializing a full-size threshold tensor.
        thd = torch.tensor(thd, device=strong_preds.device).reshape(1, 1, 1)
        hard_preds = strong_preds > thd
        smoothed.append(median_filter(hard_preds.float()))

    # A single threshold keeps the 3-D layout; several are stacked on a new axis.
    if len(smoothed) == 1:
        return smoothed[0]
    return torch.stack(smoothed)